#!/usr/bin/env python3
"""
Phase 1: Data Assembly for Causal Identification Paper

Builds the causal identification panel from existing project data
plus additional downloads. Produces:

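1. causal_panel.csv — Main estimation panel (140 countries; 1992-2024 estimation
   window, with 1986-1991 retained as pre-treatment years)
2. lagged_instruments.csv — Lagged fertility/Z instruments (20, 25, 30yr)
3. bartik_instrument.csv — Demographic shift-share instruments
4. treatment_cohorts.csv — Capital account opening episodes and cohort assignments
5. bop_components.csv — Balance of payments sub-components (planned; not yet
   written by this script — BOP series currently go to data/raw/wdi_additional.csv
   and are merged into causal_panel.csv)
6. scm_covariates.csv — Additional covariates for synthetic control matching
   (planned; covariates are currently merged into causal_panel.csv instead)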

Data sources:
- Existing: UN WPP demographics, IMF WEO, KAOPEN, EWN, WDI, PWT
- New downloads: BOP components, education, institutional quality (V-Dem proxy via WGI)
"""

import sys
import os
import pandas as pd
import numpy as np
from pathlib import Path

# Project paths. The root defaults to the original location and can be
# overridden with the DEMOGRAPHICS_PROJECT_DIR environment variable
# (an empty value falls back to the default).
PROJECT_DIR = Path(
    os.environ.get("DEMOGRAPHICS_PROJECT_DIR") or "/mnt/c/demographics_capital_flows"
)
MULTILATERAL_DIR = PROJECT_DIR / "multilateral"
CAUSAL_DIR = PROJECT_DIR / "causal_identification"
RAW_DIR = CAUSAL_DIR / "data" / "raw"
PROCESSED_DIR = CAUSAL_DIR / "data" / "processed"
# Output directories are created at import time so every step can write to them.
RAW_DIR.mkdir(parents=True, exist_ok=True)
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)

# Add project to path for imports (multilateral first, then multilateral/followup)
sys.path.insert(0, str(MULTILATERAL_DIR))
sys.path.insert(1, str(MULTILATERAL_DIR / "followup"))

# CCA countries: 12 post-Soviet states + Mongolia (13 total)
CCA_COUNTRIES = [
    'ARM', 'AZE', 'BLR', 'GEO', 'KAZ', 'KGZ', 'MDA',
    'MNG', 'RUS', 'TJK', 'TKM', 'UKR', 'UZB'
]

# Baltic states (useful comparison — opened earliest, now EU)
BALTIC_COUNTRIES = ['EST', 'LVA', 'LTU']

# Central and Eastern European transition economies
CEE_COUNTRIES = [
    'ALB', 'BGR', 'BIH', 'HRV', 'CZE', 'HUN', 'MKD', 'MNE',
    'POL', 'ROU', 'SRB', 'SVK', 'SVN'
]
# All transition economies (CCA + Baltics + CEE)
ALL_TRANSITION = CCA_COUNTRIES + BALTIC_COUNTRIES + CEE_COUNTRIES


# =====================================================================
# STEP 1: Load existing processed data
# =====================================================================

def load_existing_data():
    """Read the project's existing processed inputs from the multilateral tree.

    The followup ``full_panel.csv`` (the variant with remittances merged in)
    is preferred; the main multilateral ``full_panel.csv`` is used only when
    the followup file does not exist. The demographic shares and demographic
    polynomial tables are then read from ``multilateral/data/processed``.

    Returns:
        tuple: ``(panel, demo_shares, demo_poly)`` as pandas DataFrames.
    """
    rule = "=" * 70
    print(rule)
    print("STEP 1: Loading existing project data")
    print(rule)

    processed = MULTILATERAL_DIR / "data" / "processed"
    followup_path = MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv"

    if not followup_path.exists():
        panel = pd.read_csv(processed / "full_panel.csv", low_memory=False)
        print(f"  Loaded main panel: {panel.shape}")
    else:
        panel = pd.read_csv(followup_path, low_memory=False)
        print(f"  Loaded followup panel: {panel.shape}")

    demo_shares = pd.read_csv(processed / "demographic_shares.csv", low_memory=False)
    print(f"  Loaded demographic shares: {demo_shares.shape}")
    share_years = demo_shares['year']
    print(f"    Countries: {demo_shares['iso3'].nunique()}, "
          f"Years: {share_years.min()}-{share_years.max()}")

    demo_poly = pd.read_csv(processed / "demographic_polynomials.csv", low_memory=False)
    print(f"  Loaded demographic polynomials: {demo_poly.shape}")

    return panel, demo_shares, demo_poly


# =====================================================================
# STEP 2: Construct lagged fertility instruments
# =====================================================================

def construct_lagged_instruments(demo_poly, lags=(20, 25, 30), share_lag=25):
    """
    Construct lagged demographic instruments for IV estimation.

    For each lag L in ``lags``, adds columns ``Z_1_lag{L}``, ``Z_2_lag{L}``
    and ``Z_3_lag{L}``, holding the Z polynomial values observed L years
    earlier for the same country. When the demeaned age-share columns
    ``d_n_1`` .. ``d_n_17`` are present in ``demo_poly``, they are also added
    lagged by ``share_lag`` years as ``d_n_{g}_lag{share_lag}``.

    Logic: births L years ago determine current age-L cohort sizes.
    These are predetermined (exogenous to current CA shocks).

    Args:
        demo_poly: DataFrame with at least iso3, year, Z_1, Z_2, Z_3.
        lags: iterable of integer lags in years (default 20, 25, 30).
        share_lag: lag in years applied to the raw share columns (default 25).

    Returns:
        DataFrame with iso3, year, the current Z_1..Z_3 and all lagged
        columns. It is also written to PROCESSED_DIR / "lagged_instruments.csv".
    """
    print("\n" + "=" * 70)
    print("STEP 2: Constructing lagged fertility instruments")
    print("=" * 70)

    # `lags` is iterated twice (construction and diagnostics), so materialize
    # it in case the caller passed a one-shot iterator.
    lags = tuple(lags)
    z_cols = ['Z_1', 'Z_2', 'Z_3']

    # Start with current Z values
    instruments = demo_poly[['iso3', 'year'] + z_cols].copy()

    for lag in lags:
        print(f"\n  Constructing {lag}-year lagged Z instruments...")

        # Shifting the source year forward by `lag` makes the (iso3, year)
        # merge attach the value observed `lag` years earlier.
        lagged = demo_poly[['iso3', 'year'] + z_cols].copy()
        lagged['year'] = lagged['year'] + lag
        lagged = lagged.rename(columns={z: f'{z}_lag{lag}' for z in z_cols})

        instruments = instruments.merge(lagged, on=['iso3', 'year'], how='left')

        # Report coverage (guard against an empty input frame)
        total = len(instruments)
        for z in z_cols:
            col = f'{z}_lag{lag}'
            non_null = instruments[col].notna().sum()
            pct = 100 * non_null / total if total else 0.0
            print(f"    {col}: {non_null}/{total} non-null ({pct:.1f}%)")

    # Lagged raw (demeaned) age shares as alternative instruments:
    # up to 17 instruments for 3 endogenous regressors, which leaves room
    # for overidentification tests.
    print(f"\n  Constructing {share_lag}-year lagged raw age shares...")
    share_cols = [f'd_n_{g}' for g in range(1, 18)]
    available_share_cols = [c for c in share_cols if c in demo_poly.columns]

    if available_share_cols:
        lagged_shares = demo_poly[['iso3', 'year'] + available_share_cols].copy()
        lagged_shares['year'] = lagged_shares['year'] + share_lag
        lagged_shares = lagged_shares.rename(
            columns={c: f'{c}_lag{share_lag}' for c in available_share_cols}
        )
        instruments = instruments.merge(lagged_shares, on=['iso3', 'year'], how='left')
        print(f"    Added {len(available_share_cols)} lagged share columns")
    else:
        print("    WARNING: demeaned shares not in polynomial file, skipping")

    # Save
    instruments.to_csv(PROCESSED_DIR / "lagged_instruments.csv", index=False)
    print(f"\n  Saved lagged instruments: {instruments.shape}")

    # Diagnostics: correlation between current and lagged Z in the estimation window
    print("\n  Correlation diagnostics (current vs lagged Z):")
    est_window = instruments[instruments['year'].between(1992, 2024)]
    for lag in lags:
        for z in z_cols:
            corr = est_window[[z, f'{z}_lag{lag}']].dropna().corr().iloc[0, 1]
            print(f"    corr({z}, {z}_lag{lag}) = {corr:.4f}")

    return instruments


# =====================================================================
# STEP 3: Identify treatment cohorts for staggered DiD
# =====================================================================

def _first_opening_event(years, kaopen_vals):
    """Return ``(year, type)`` of the earliest KAOPEN opening event, or ``(None, None)``.

    Consecutive observations are scanned for two event types:
    - ``'crossed_zero'``: KAOPEN moves from below 0 to 0 or above;
    - ``'big_jump'``: KAOPEN rises by more than 1 point.
    The first occurrence of each is recorded. If both occur, the earlier year
    wins; ties go to ``'crossed_zero'``.
    """
    crossed_zero = None
    big_jump = None
    for year, prev, cur in zip(years[1:], kaopen_vals[:-1], kaopen_vals[1:]):
        if crossed_zero is None and prev < 0 <= cur:
            crossed_zero = int(year)
        if big_jump is None and cur - prev > 1.0:
            big_jump = int(year)
        if crossed_zero is not None and big_jump is not None:
            break

    if crossed_zero is not None and (big_jump is None or crossed_zero <= big_jump):
        return crossed_zero, 'crossed_zero'
    if big_jump is not None:
        return big_jump, 'big_jump'
    return None, None


def identify_treatment_cohorts(panel):
    """
    Identify capital account opening episodes and assign treatment cohorts.

    Treatment definition: the first year a country's KAOPEN crosses 0 from
    below, OR the first year KAOPEN jumps by more than 1 point, whichever
    comes first. Only rows with non-null KAOPEN in 1970-2024 are used.
    Duplicate iso3-year rows are dropped (first row in panel order kept), so
    consecutive observations are always distinct years. Countries with fewer
    than 3 observations are skipped.

    Status per country:
    - ``always_open``: first KAOPEN >= 0 and minimum KAOPEN >= -0.5;
    - ``opener``: otherwise, if an opening event was found;
    - ``never_opened``: otherwise.

    Returns:
        DataFrame with one row per country, including:
        - iso3, opening_year, opening_type, status, cohort_group (5-year bins,
          openers only)
        - pre-/post-opening mean KAOPEN and summary KAOPEN statistics
        - transition-group flags (is_cca, is_baltic, is_cee, is_transition)
        It is also written to PROCESSED_DIR / "treatment_cohorts.csv".
    """
    print("\n" + "=" * 70)
    print("STEP 3: Identifying treatment cohorts")
    print("=" * 70)

    # Focus on the actual data period. Drop duplicate iso3-year rows before
    # sorting: some countries appear more than once per year, and duplicates
    # would make the consecutive-difference checks compare a year with itself.
    ka = (
        panel.loc[
            panel['kaopen'].notna() & panel['year'].between(1970, 2024),
            ['iso3', 'year', 'kaopen'],
        ]
        .drop_duplicates(subset=['iso3', 'year'], keep='first')
        .sort_values(['iso3', 'year'])
    )

    result_columns = [
        'iso3', 'opening_year', 'opening_type', 'status', 'first_kaopen',
        'mean_kaopen', 'min_kaopen', 'max_kaopen', 'pre_opening_kaopen',
        'post_opening_kaopen', 'first_year_observed', 'last_year_observed',
        'n_years', 'is_cca', 'is_baltic', 'is_cee', 'is_transition',
    ]
    results = []

    for iso3, cdf in ka.groupby('iso3'):
        if len(cdf) < 3:
            continue

        years = cdf['year'].to_numpy()
        kaopen_vals = cdf['kaopen'].to_numpy()
        opening_year, opening_type = _first_opening_event(years, kaopen_vals)

        mean_kaopen = kaopen_vals.mean()
        min_kaopen = kaopen_vals.min()
        max_kaopen = kaopen_vals.max()
        first_kaopen = kaopen_vals[0]

        # Classify: always open, opened, or never opened
        if first_kaopen >= 0 and min_kaopen >= -0.5:
            status = 'always_open'
        elif opening_year is not None:
            status = 'opener'
        else:
            status = 'never_opened'

        # Pre/post KAOPEN averages around the opening year (if any)
        if opening_year is not None:
            pre_vals = kaopen_vals[years < opening_year]
            post_vals = kaopen_vals[years >= opening_year]
            pre_kaopen = pre_vals.mean() if pre_vals.size else np.nan
            post_kaopen = post_vals.mean() if post_vals.size else np.nan
        else:
            pre_kaopen = post_kaopen = mean_kaopen

        results.append({
            'iso3': iso3,
            'opening_year': opening_year,
            'opening_type': opening_type,
            'status': status,
            'first_kaopen': first_kaopen,
            'mean_kaopen': mean_kaopen,
            'min_kaopen': min_kaopen,
            'max_kaopen': max_kaopen,
            'pre_opening_kaopen': pre_kaopen,
            'post_opening_kaopen': post_kaopen,
            'first_year_observed': years[0],
            'last_year_observed': years[-1],
            'n_years': len(years),
            'is_cca': iso3 in CCA_COUNTRIES,
            'is_baltic': iso3 in BALTIC_COUNTRIES,
            'is_cee': iso3 in CEE_COUNTRIES,
            'is_transition': iso3 in ALL_TRANSITION,
        })

    # Explicit columns keep the schema intact even when no country qualifies.
    cohorts = pd.DataFrame(results, columns=result_columns)

    # Assign cohort groups (5-year bins for staggered DiD); openers only
    openers = cohorts[cohorts['status'] == 'opener'].copy()
    if len(openers) > 0:
        bins = [1970, 1990, 1995, 2000, 2005, 2010, 2015, 2025]
        labels = ['pre1990', '1990-94', '1995-99', '2000-04', '2005-09', '2010-14', '2015+']
        openers['cohort_group'] = pd.cut(
            openers['opening_year'], bins=bins, labels=labels, right=False
        )
        cohorts = cohorts.merge(
            openers[['iso3', 'cohort_group']], on='iso3', how='left'
        )
    else:
        cohorts['cohort_group'] = np.nan

    # Print summary
    print("\n  Treatment cohort summary:")
    print(f"    Always open:  {(cohorts['status'] == 'always_open').sum()} countries")
    print(f"    Openers:      {(cohorts['status'] == 'opener').sum()} countries")
    print(f"    Never opened: {(cohorts['status'] == 'never_opened').sum()} countries")

    for title, flag in (("CCA countries", 'is_cca'), ("Baltic countries", 'is_baltic')):
        print(f"\n  {title}:")
        for _, row in cohorts[cohorts[flag].astype(bool)].iterrows():
            print(f"    {row['iso3']}: status={row['status']}, "
                  f"opening_year={row['opening_year']}, "
                  f"KAOPEN range=[{row['min_kaopen']:.2f}, {row['max_kaopen']:.2f}]")

    print("\n  Opener cohort distribution:")
    opener_cohorts = cohorts[cohorts['status'] == 'opener']
    print(opener_cohorts['cohort_group'].value_counts().sort_index().to_string())

    # Save
    cohorts.to_csv(PROCESSED_DIR / "treatment_cohorts.csv", index=False)
    print(f"\n  Saved treatment cohorts: {cohorts.shape}")

    return cohorts


# =====================================================================
# STEP 4: Download additional WDI data (BOP components, education)
# =====================================================================

def download_additional_wdi():
    """
    Fetch supplementary World Bank series and cache them on disk.

    When RAW_DIR / "wdi_additional.csv" already exists, it is read and
    returned without any network access. Otherwise each indicator is fetched
    with ``wbgapi`` for 1970-2024 and reshaped to long form (iso3, year,
    value). Indicators that fail are reported and skipped. The per-indicator
    frames are outer-joined on (iso3, year), saved, and returned. If
    ``wbgapi`` cannot be imported, the REST API fallback is used instead.

    Series fetched: net trade in goods and services and personal remittances
    (received % GDP, paid USD), education spending and enrollment ratios,
    four WGI governance estimates, gross savings, gross capital formation,
    and net barter terms of trade.

    Returns:
        Wide DataFrame keyed by iso3 and year, or None if nothing downloaded.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("STEP 4: Downloading additional WDI data")
    print(banner)

    cache_path = RAW_DIR / "wdi_additional.csv"
    if cache_path.exists():
        print(f"  Additional WDI data already exists at {cache_path}")
        return pd.read_csv(cache_path)

    try:
        import wbgapi as wb
    except ImportError:
        print("  wbgapi not installed. Install with: pip install wbgapi")
        print("  Attempting REST API fallback...")
        return _download_wdi_rest_fallback()

    series_map = {
        # BOP components (% of GDP)
        'BN.GSR.GNFS.GD.ZS': 'goods_services_balance_gdp',
        'BX.TRF.PWKR.DT.GD.ZS': 'remittances_received_gdp',
        'BM.TRF.PWKR.CD.DT': 'remittances_paid_usd',
        # Education
        'SE.XPD.TOTL.GD.ZS': 'education_exp_gdp',
        'SE.TER.ENRR': 'tertiary_enrollment',
        'SE.SEC.ENRR': 'secondary_enrollment',
        # Governance (WGI — proxy for institutional quality)
        'CC.EST': 'control_corruption',
        'GE.EST': 'govt_effectiveness',
        'RQ.EST': 'regulatory_quality',
        'RL.EST': 'rule_of_law',
        # Additional macro
        'NY.GNS.ICTR.GN.ZS': 'gross_savings_gni',
        'NE.GDI.TOTL.ZS': 'gross_investment_gdp',
        'TT.PRI.MRCH.XD.WD': 'terms_of_trade',
    }

    downloaded = []
    for code, name in series_map.items():
        try:
            print(f"  Fetching {name} ({code})...")
            wide = wb.data.DataFrame(code, time=range(1970, 2025), labels=False).reset_index()
            long_df = wide.melt(id_vars=['economy'], var_name='year', value_name=name)
            # wbgapi labels year columns as "YR1970", "YR1971", ...
            long_df['year'] = long_df['year'].str.replace('YR', '').astype(int)
            long_df = long_df.rename(columns={'economy': 'iso3'}).dropna(subset=[name])
            downloaded.append(long_df)
            print(f"    Got {len(long_df)} observations")
        except Exception as err:
            print(f"    WARNING: Could not get {code}: {err}")

    if not downloaded:
        print("  WARNING: No data downloaded")
        return None

    combined, *remaining = downloaded
    for piece in remaining:
        combined = combined.merge(piece, on=['iso3', 'year'], how='outer')
    combined.to_csv(cache_path, index=False)
    print(f"\n  Saved additional WDI data: {combined.shape}")

    # Report CCA coverage per downloaded series
    cca_rows = combined[combined['iso3'].isin(CCA_COUNTRIES)]
    print("\n  CCA coverage in additional WDI data:")
    value_cols = [c for c in combined.columns if c not in ('iso3', 'year')]
    for col in value_cols:
        count = cca_rows[col].notna().sum()
        if count > 0:
            print(f"    {col}: {count} obs")

    return combined


def _download_wdi_rest_fallback():
    """
    Fallback WDI download via the World Bank REST API v2 (no wbgapi needed).

    Each indicator is paged through for all economies over 1970-2024, and
    only non-null values are kept. Country codes come from the record's
    ``countryiso3code`` field, which is ISO3 and matches the project panels.
    The ``country.id`` field is ISO2; it is used only when the ISO3 field is
    blank. Per-indicator failures (HTTP errors, bad payloads) are reported
    and that indicator is skipped.

    Returns:
        Wide DataFrame (iso3, year, one column per indicator), also written to
        RAW_DIR / "wdi_additional.csv", or None if nothing was fetched.
    """
    import requests

    # Same series as the wbgapi path in download_additional_wdi, so the cached
    # wdi_additional.csv has the same columns whichever path produced it.
    indicators = {
        'BN.GSR.GNFS.GD.ZS': 'goods_services_balance_gdp',
        'BX.TRF.PWKR.DT.GD.ZS': 'remittances_received_gdp',
        'BM.TRF.PWKR.CD.DT': 'remittances_paid_usd',
        'SE.XPD.TOTL.GD.ZS': 'education_exp_gdp',
        'SE.TER.ENRR': 'tertiary_enrollment',
        'SE.SEC.ENRR': 'secondary_enrollment',
        'CC.EST': 'control_corruption',
        'GE.EST': 'govt_effectiveness',
        'RQ.EST': 'regulatory_quality',
        'RL.EST': 'rule_of_law',
        'NY.GNS.ICTR.GN.ZS': 'gross_savings_gni',
        'NE.GDI.TOTL.ZS': 'gross_investment_gdp',
        'TT.PRI.MRCH.XD.WD': 'terms_of_trade',
    }

    all_rows = []
    for ind_code, var_name in indicators.items():
        print(f"  REST API: fetching {var_name}...")
        page = 1
        total_fetched = 0
        while True:
            url = (f"https://api.worldbank.org/v2/country/all/indicator/"
                   f"{ind_code}?date=1970:2024&format=json&per_page=10000&page={page}")
            try:
                resp = requests.get(url, timeout=60)
                if resp.status_code != 200:
                    print(f"    HTTP {resp.status_code}")
                    break
                data = resp.json()
                # Payload is [metadata, records]; an error payload has one element.
                if len(data) < 2 or not data[1]:
                    break
                for r in data[1]:
                    if r['value'] is None:
                        continue
                    iso3 = r.get('countryiso3code') or r['country']['id']
                    all_rows.append({
                        'iso3': iso3,
                        'year': int(r['date']),
                        var_name: float(r['value'])
                    })
                    total_fetched += 1
                # Check if more pages
                if page >= data[0].get('pages', 1):
                    break
                page += 1
            except Exception as e:  # best effort per indicator: report and move on
                print(f"    Error: {e}")
                break
        print(f"    Got {total_fetched} observations")

    if not all_rows:
        return None

    # Each row carries one indicator; GroupBy.first takes the first non-null
    # value per column, which pivots the long rows to one row per iso3-year.
    df = pd.DataFrame(all_rows).groupby(['iso3', 'year']).first().reset_index()
    outfile = RAW_DIR / "wdi_additional.csv"
    df.to_csv(outfile, index=False)
    print(f"\n  Saved: {df.shape}")
    return df


# =====================================================================
# STEP 5: Construct Demographic Bartik Instrument
# =====================================================================

def construct_bartik_instrument(demo_shares, t0=1990):
    """
    Construct shift-share (Bartik) demographic instruments.

    bartik_exposure_it = Σ_k share_ik,t0 × Δglobal_share_k,t

    where:
    - share_ik,t0 = country i's age-group k share in the baseline year t0
    - Δglobal_share_k,t = change in the population-weighted global share of
      age group k from t0 to t

    The polynomial versions weight each group's term by g**p, mirroring
    Z_p = Σ_g g^p · n_g:
        Z_p_bartik_it = Σ_g g^p · share_ig,t0 · Δglobal_share_g,t,  p = 1, 2, 3.

    Identification: global demographic trends (shifts) are exogenous to
    any individual country's current account. Country i's initial age
    structure (shares) captures differential exposure.

    The global share for a year is a total_pop-weighted average over all rows
    of ``demo_shares`` for that year. Rows missing the share or the population
    are excluded from that share's average. The average includes country i
    itself (no leave-one-out).
    NOTE(review): if demographic_shares.csv contains regional/world aggregate
    rows, they would enter the global average too — confirm.

    Args:
        demo_shares: DataFrame with iso3, year, total_pop and age-group
            shares n_1..n_17 (5-year bands, the last being 80+).
        t0: baseline year (default 1990).

    Returns:
        DataFrame with iso3, year, Z_1_bartik, Z_2_bartik, Z_3_bartik and
        bartik_exposure for 1950-2024, also written to
        PROCESSED_DIR / "bartik_instrument.csv".

    Raises:
        ValueError: if the baseline year ``t0`` is not in the data.
    """
    print("\n" + "=" * 70)
    print("STEP 5: Constructing Demographic Bartik Instrument")
    print("=" * 70)

    n_groups = 17
    share_cols = [f'n_{g}' for g in range(1, n_groups + 1)]

    df = demo_shares[demo_shares['year'].between(1950, 2024)].copy()

    # --- Step A: population-weighted global age shares by year ---
    # Computed as Σ(share·pop)/Σ(pop) over rows where both are present, so a
    # single missing value does not turn the whole year's share into NaN.
    pop = df['total_pop']
    global_cols = {}
    for col in share_cols:
        valid = df[col].notna() & pop.notna()
        weighted_sum = (df[col] * pop).where(valid).groupby(df['year']).sum()
        weight_total = pop.where(valid).groupby(df['year']).sum()
        global_cols[col] = weighted_sum / weight_total.replace(0, np.nan)
    global_shares = pd.DataFrame(global_cols).rename_axis('year').reset_index()

    # Compute change from baseline
    base_row = global_shares.loc[global_shares['year'] == t0, share_cols]
    if base_row.empty:
        raise ValueError(f"Baseline year {t0} not found in demographic shares")
    baseline_global = base_row.iloc[0]
    for col in share_cols:
        global_shares[f'delta_global_{col}'] = global_shares[col] - baseline_global[col]

    # Report changes for the latest available year (2024 when present)
    report_year = global_shares['year'].max()
    print(f"  Global share changes from {t0}:")
    latest = global_shares[global_shares['year'] == report_year].iloc[0]
    for g, col in enumerate(share_cols, 1):
        delta = latest[f'delta_global_{col}']
        age = f"{(g - 1) * 5}-{g * 5 - 1}" if g < n_groups else "80+"
        print(f"    {col} (age {age}): "
              f"{delta:+.4f} ({delta*100:+.2f}pp)")

    # --- Step B: country baseline shares ---
    baseline_shares = df.loc[df['year'] == t0, ['iso3'] + share_cols].rename(
        columns={col: f'base_{col}' for col in share_cols}
    )
    print(f"\n  Baseline shares ({t0}): {len(baseline_shares)} countries")

    # --- Step C: shift-share instruments ---
    base_cols = [f'base_{col}' for col in share_cols]
    delta_cols = [f'delta_global_{col}' for col in share_cols]
    bartik_df = (
        df[['iso3', 'year']]
        .merge(baseline_shares, on='iso3', how='left')
        .merge(global_shares[['year'] + delta_cols], on='year', how='left')
    )

    # contrib[r, k] = baseline share of group k × global change in group k.
    # A NaN in any term propagates to NaN in the sums, as with Series addition.
    contrib = (bartik_df[base_cols].to_numpy(dtype=float)
               * bartik_df[delta_cols].to_numpy(dtype=float))
    g_indices = np.arange(1, n_groups + 1)
    for p in (1, 2, 3):
        bartik_df[f'Z_{p}_bartik'] = contrib @ (g_indices ** p)
    bartik_df['bartik_exposure'] = contrib.sum(axis=1)

    # Keep only the instrument columns
    result = bartik_df[['iso3', 'year',
                        'Z_1_bartik', 'Z_2_bartik', 'Z_3_bartik',
                        'bartik_exposure']].copy()

    # Diagnostics over the estimation window
    est = result[result['year'].between(1992, 2024)].dropna()
    print(f"\n  Bartik instrument coverage: {len(est)} obs, "
          f"{est['iso3'].nunique()} countries")
    for col in ['Z_1_bartik', 'Z_2_bartik', 'Z_3_bartik', 'bartik_exposure']:
        print(f"    {col}: mean={est[col].mean():.6f}, std={est[col].std():.6f}, "
              f"range=[{est[col].min():.6f}, {est[col].max():.6f}]")

    result.to_csv(PROCESSED_DIR / "bartik_instrument.csv", index=False)
    print(f"\n  Saved Bartik instruments: {result.shape}")

    return result


# =====================================================================
# STEP 6: Build integrated causal panel
# =====================================================================

def build_causal_panel(panel, instruments, bartik, cohorts, wdi_additional):
    """
    Merge all data sources into the estimation panel.

    Keeps panel rows with 1986 <= year <= 2024 and non-null ``ca_gdp``. The
    years before the 1992 (post-Soviet independence) estimation window are
    retained as pre-treatment observations. Duplicate iso3-year rows are
    dropped (first occurrence kept) before any merge.

    Args:
        panel: base country-year panel (needs iso3, year, ca_gdp).
        instruments: output of construct_lagged_instruments, or None. The
            current Z_1..Z_3 are not re-merged; the panel's own values are kept.
        bartik: output of construct_bartik_instrument, or None.
        cohorts: output of identify_treatment_cohorts, or None. Only the
            cohort-specific columns are merged; group flags are rebuilt below.
        wdi_additional: output of download_additional_wdi, or None. Only
            columns not already in the panel are merged, so existing series
            are not shadowed by ``_x``/``_y`` suffixed copies.

    Returns:
        The merged DataFrame, also written to PROCESSED_DIR / "causal_panel.csv".
    """
    print("\n" + "=" * 70)
    print("STEP 6: Building integrated causal panel")
    print("=" * 70)

    key = ['iso3', 'year']

    # Start with the main panel; 1986-1991 are kept as pre-treatment years
    df = panel[
        (panel['year'] >= 1986) &
        (panel['year'] <= 2024) &
        (panel['ca_gdp'].notna())
    ].copy()

    # Drop duplicates (some countries like MNG have multiple entries per year)
    pre_dedup = len(df)
    df = df.drop_duplicates(subset=key, keep='first')
    if len(df) < pre_dedup:
        print(f"  Dropped {pre_dedup - len(df)} duplicate iso3-year rows")

    print(f"  Base panel (1986-2024, non-null CA): {len(df)} obs, "
          f"{df['iso3'].nunique()} countries")

    # Merge lagged instruments (current Z columns already live in the panel)
    if instruments is not None:
        inst_cols = [c for c in instruments.columns
                     if c not in key and c not in ('Z_1', 'Z_2', 'Z_3')]
        df = df.merge(instruments[key + inst_cols], on=key, how='left')
        print(f"  Merged lagged instruments: {len(inst_cols)} columns")

    # Merge Bartik
    if bartik is not None:
        df = df.merge(bartik, on=key, how='left')
        print("  Merged Bartik instruments")

    # Merge treatment cohort info (group flags are recomputed further down)
    if cohorts is not None:
        cohort_cols = ['iso3', 'opening_year', 'opening_type', 'status', 'cohort_group']
        available_cols = [c for c in cohort_cols if c in cohorts.columns]
        df = df.merge(cohorts[available_cols], on='iso3', how='left')
        print("  Merged treatment cohort assignments")

    # Merge additional WDI, without duplicating columns the panel already has
    if wdi_additional is not None:
        wdi_value_cols = [c for c in wdi_additional.columns if c not in key]
        new_cols = [c for c in wdi_value_cols if c not in df.columns]
        overlap = [c for c in wdi_value_cols if c in df.columns]
        if overlap:
            print(f"  Keeping existing panel columns (not overwritten by WDI): {overlap}")
        df = df.merge(wdi_additional[key + new_cols], on=key, how='left')
        print(f"  Merged additional WDI: {len(new_cols)} columns")

    # Construct event-time variables for staggered DiD (NaN without an opening year)
    if 'opening_year' in df.columns:
        has_opening = df['opening_year'].notna()
        df['event_time'] = (df['year'] - df['opening_year']).where(has_opening)
        df['post_opening'] = (df['event_time'] >= 0).astype(float).where(has_opening)
        print("  Constructed event_time and post_opening variables")

    # Transition-economy indicators as 0/1 floats, computed for every row so
    # countries without a cohort record are still flagged consistently.
    df['is_cca'] = df['iso3'].isin(CCA_COUNTRIES).astype(float)
    df['is_baltic'] = df['iso3'].isin(BALTIC_COUNTRIES).astype(float)
    df['is_cee'] = df['iso3'].isin(CEE_COUNTRIES).astype(float)
    df['is_transition'] = df['iso3'].isin(ALL_TRANSITION).astype(float)

    # Summary statistics
    print(f"\n  Final causal panel: {len(df)} obs, {df['iso3'].nunique()} countries")
    print(f"  Year range: {df['year'].min()}-{df['year'].max()}")

    print("\n  Sample composition:")
    print(f"    CCA:         {df[df['is_cca']==1]['iso3'].nunique()} countries, "
          f"{(df['is_cca']==1).sum()} obs")
    print(f"    Baltic:      {df[df['is_baltic']==1]['iso3'].nunique()} countries, "
          f"{(df['is_baltic']==1).sum()} obs")
    print(f"    Transition:  {df[df['is_transition']==1]['iso3'].nunique()} countries, "
          f"{(df['is_transition']==1).sum()} obs")
    print(f"    Non-trans:   {df[df['is_transition']==0]['iso3'].nunique()} countries, "
          f"{(df['is_transition']==0).sum()} obs")

    if 'status' in df.columns:
        print("\n  Treatment status:")
        for st in ['always_open', 'opener', 'never_opened']:
            n_countries = df[df['status'] == st]['iso3'].nunique()
            n_obs = (df['status'] == st).sum()
            print(f"    {st}: {n_countries} countries, {n_obs} obs")

    # Variable coverage (guard against an empty panel)
    print("\n  Key variable coverage:")
    key_vars = ['ca_gdp', 'Z_1', 'Z_2', 'Z_3', 'kaopen',
                'fiscal_bal_gdp', 'nfa_gdp_lag', 'trade_openness',
                'Z_1_lag20', 'Z_1_lag25', 'Z_1_lag30',
                'Z_1_bartik', 'bartik_exposure',
                'control_corruption', 'tertiary_enrollment',
                'gross_savings_gni']
    n_rows = len(df)
    for v in key_vars:
        if v in df.columns:
            n = df[v].notna().sum()
            pct = 100 * n / n_rows if n_rows else 0.0
            print(f"    {v}: {n} obs ({pct:.1f}%)")

    # Save
    df.to_csv(PROCESSED_DIR / "causal_panel.csv", index=False)
    print(f"\n  Saved causal panel: {df.shape}")
    print(f"  Location: {PROCESSED_DIR / 'causal_panel.csv'}")

    return df


# =====================================================================
# STEP 7: Diagnostic summary for CCA countries
# =====================================================================

def cca_diagnostics(df):
    """
    Print diagnostics for the CCA rows (``is_cca == 1``) of the causal panel.

    The report covers:
    - per-country observation counts and year ranges;
    - mean ``ca_gdp``, ``Z_1`` and ``kaopen``;
    - opening year and status;
    - lagged/Bartik instrument coverage;
    - a pre- vs post-opening comparison of ``ca_gdp``.

    Args:
        df: causal panel. It must contain iso3, year, is_cca, ca_gdp, Z_1 and
            kaopen. Cohort (opening_year, status), instrument and
            post_opening columns are optional.

    Returns:
        The CCA subset (a copy). When the panel has no CCA rows, the empty
        subset is returned and the remaining diagnostics are skipped.
    """
    print("\n" + "=" * 70)
    print("STEP 7: CCA Diagnostic Summary")
    print("=" * 70)

    cca = df[df['is_cca'] == 1].copy()
    n_obs = len(cca)

    print(f"\n  CCA panel: {n_obs} obs, {cca['iso3'].nunique()} countries")
    if n_obs == 0:
        # Nothing to summarize; the coverage percentages would divide by zero.
        print("  No CCA observations in panel; skipping remaining diagnostics")
        return cca
    print(f"  Year range: {cca['year'].min()}-{cca['year'].max()}")

    # Per-country summary
    print("\n  Per-country summary:")
    print(f"  {'ISO3':5s} {'N':>4s} {'Years':>12s} {'CA mean':>8s} {'Z1 mean':>8s} "
          f"{'KAOPEN mn':>9s} {'Open yr':>8s} {'Status':>12s}")
    print("  " + "-" * 75)

    has_opening = 'opening_year' in cca.columns
    has_status = 'status' in cca.columns
    for iso3, c in cca.groupby('iso3'):
        ca_mean = c['ca_gdp'].mean()
        z1_mean = c['Z_1'].mean()
        ka_mean = c['kaopen'].mean()  # NaN when the country has no KAOPEN data
        yr_range = f"{c['year'].min()}-{c['year'].max()}"

        # Opening year/status are constant within a country; show them as a
        # plain year (not "1995.0") and blank when missing.
        open_yr = ''
        if has_opening and c['opening_year'].notna().any():
            open_yr = str(int(c['opening_year'].dropna().iloc[0]))
        status = ''
        if has_status and c['status'].notna().any():
            status = str(c['status'].dropna().iloc[0])

        print(f"  {iso3:5s} {len(c):4d} {yr_range:>12s} {ca_mean:8.2f} {z1_mean:8.3f} "
              f"{ka_mean:9.2f} {open_yr:>8s} {status:>12s}")

    # Instrument coverage for CCA
    print("\n  Lagged instrument coverage for CCA:")
    for lag in [20, 25, 30]:
        col = f'Z_1_lag{lag}'
        if col in cca.columns:
            n = cca[col].notna().sum()
            print(f"    {col}: {n}/{n_obs} obs ({100*n/n_obs:.1f}%)")

    # Bartik coverage for CCA
    if 'Z_1_bartik' in cca.columns:
        n = cca['Z_1_bartik'].notna().sum()
        print(f"    Z_1_bartik: {n}/{n_obs} obs ({100*n/n_obs:.1f}%)")

    # Pre vs post opening comparison for CCA countries with an opening year
    if 'post_opening' in cca.columns:
        openers = cca[cca['opening_year'].notna()]
        if len(openers) > 0:
            print("\n  CCA openers: pre vs post opening")
            pre = openers[openers['post_opening'] == 0]
            post = openers[openers['post_opening'] == 1]
            print(f"    Pre-opening:  {len(pre)} obs, CA mean = {pre['ca_gdp'].mean():.2f}")
            print(f"    Post-opening: {len(post)} obs, CA mean = {post['ca_gdp'].mean():.2f}")

    return cca


# =====================================================================
# MAIN
# =====================================================================

if __name__ == '__main__':
    rule = "=" * 70
    print(rule)
    print("PHASE 1: DATA ASSEMBLY FOR CAUSAL IDENTIFICATION")
    print(rule)

    # Steps 1-5: load inputs, then build each component dataset in turn
    # (lagged instruments, treatment cohorts, extra WDI series, Bartik).
    panel, demo_shares, demo_poly = load_existing_data()
    instruments = construct_lagged_instruments(demo_poly)
    cohorts = identify_treatment_cohorts(panel)
    wdi_additional = download_additional_wdi()
    bartik = construct_bartik_instrument(demo_shares)

    # Step 6 merges everything into one panel; Step 7 reports on the CCA subset.
    causal_panel = build_causal_panel(panel, instruments, bartik, cohorts, wdi_additional)
    cca_diagnostics(causal_panel)

    print("\n" + rule)
    print("PHASE 1 COMPLETE")
    print(rule)
    print("\nOutput files:")
    for csv_path in sorted(PROCESSED_DIR.glob("*.csv")):
        print(f"  {csv_path.name}: {csv_path.stat().st_size / 1e6:.1f} MB")
